#!/usr/bin/env python

"""
This script normalizes for sampling depth based on the median-ratio method as performed in DESeq2. But 
unlike DESeq2, we don't care here if there are floats in there. I use this for normalizing metagenomic 
coverage data, like gene-level coverage, or summed KO coverages. These are normalized for gene-length already
because they are "coverages", but they are not yet normalized for sampling depth – which is where this script 
comes in. 

I also found myself wanting this because I wanted to do differential abundance testing of coverages
of KO terms. DESeq2 doesn't require normalizing for gene-length because it is the same unit being analyzed
across all samples – the same gene, so the same size. However, after grouping genes into their KO annotations,
they no longer all represent the same units across all samples. It is because of this I decided to stick with 
gene-level coverages (which are normalized for gene-length), and then sum those values based on KO annotations.
I wanted to have access to this method of normalization in addition to the coverage per million method I usually
employ.

Normalization was initially described in this paper (http://dx.doi.org/10.1186/gb-2010-11-10-r106; 
eq. 5), and this site is super-informative in general about the DESeq2 process and helped me 
understand it well enough to implement it: 
https://hbctraining.github.io/DGE_workshop/lessons/02_DGE_count_normalization.html
"""

import os
import sys
import argparse
import pandas as pd
import numpy as np
from scipy.stats.mstats import gmean

def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with a required ``-i/--input-table``
        option and an optional ``-o/--output-table`` option.
    """
    # The description keeps its original backslash continuations so the string
    # is unchanged; argparse's default help formatter collapses the embedded
    # runs of whitespace when it wraps the text.
    parser = argparse.ArgumentParser(description="This script normalizes a table for sampling depth based on the \
                                              median-ratio method as performed in DESeq2. See note at top of \
                                              script for more info. It expects a tab-delimited table with \
                                              samples as columns and genes/units as rows. For version \
                                              info, run `bit-version`.")

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input-table", help="Input tab-delimited table", action="store", required=True)
    parser.add_argument("-o", "--output-table", help='Output filename (default: "MR-normalized.tsv")', action="store", default="MR-normalized.tsv")

    return parser


def median_ratio_normalize(tab):
    """Normalize a table for sampling depth with the median-ratio method.

    For every row (gene/unit) the geometric mean across samples is computed.
    Each sample's size factor is the median, over the usable rows, of the
    ratio between the sample's value and the row's geometric mean. Every
    column is then divided by its size factor.

    Args:
        tab (pandas.DataFrame): numeric table with samples as columns and
            genes/units as rows.

    Returns:
        pandas.DataFrame: table with the same index and columns as ``tab``,
        each column divided by its sample's size factor.

    Raises:
        ValueError: if no row has a positive geometric mean (for example,
            every row contains at least one zero), because the size factors
            would then all be NaN. Also raised if the table cannot be
            converted to floats.
    """
    values = tab.to_numpy(dtype=float)

    # A zero anywhere in a row gives log(0) = -inf and so a geometric mean of
    # 0; the divide warning is expected, and such rows are filtered out below.
    with np.errstate(divide='ignore'):
        geomeans = np.asarray(gmean(values, axis=1))

    # Only rows with a positive geometric mean can contribute ratios. Rows
    # whose geometric mean is 0 or NaN compare False here and are dropped.
    usable = geomeans > 0
    if not usable.any():
        # Without this check the medians below are all NaN and the whole
        # output table would silently become NaN.
        raise ValueError("cannot compute size factors: no row has a positive "
                         "geometric mean (every row contains at least one "
                         "zero, non-positive, or missing value)")

    # Ratio of each value to its row's geometric mean, usable rows only.
    ratios = tab.loc[usable].div(geomeans[usable], axis=0)

    # One size factor per sample (column): the median ratio over usable rows.
    size_factors = ratios.median()

    # Divide positionally (one factor per column, in column order).
    return tab / size_factors.to_numpy()


def main():
    """Parse the command line, normalize the input table and write the result."""
    parser = build_parser()

    # With no arguments at all, show the help text instead of an argparse error.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args()

    tab = pd.read_csv(args.input_table, sep="\t", index_col=0)

    try:
        norm_tab = median_ratio_normalize(tab)
    except ValueError as err:
        # Report unusable input on stderr with a non-zero exit status rather
        # than writing out a table of NaNs.
        sys.exit("Error: " + str(err))

    # writing out normalized table
    norm_tab.to_csv(args.output_table, sep="\t")


if __name__ == "__main__":
    main()
